import torch
from torch import nn, Tensor
from .utils import cos_sim
import numpy as np

class ContrastiveTensionLoss(nn.Module):
    """
        This loss expects as input a batch consisting of multiple mini-batches of sentence pairs (a_1, p_1), (a_2, p_2)..., (a_{K+1}, p_{K+1})
        where p_1 = a_1 = a_2 = ... a_{K+1} and p_2, p_3, ..., p_{K+1} are expected to be different from p_1 (this is done via random sampling).
        The corresponding labels y_1, y_2, ..., y_{K+1} for each mini-batch are assigned as: y_i = 1 if i == 1 and y_i = 0 otherwise.
        In other words, K represent the number of negative pairs and the positive pair is actually made of two identical sentences. The data generation
        process has already been implemented in readers/ContrastiveTensionReader.py
        For tractable optimization, two independent encoders ('model1' and 'model2') are created for encoding a_i and p_i, respectively. For inference,
        only model2 are used, which gives better performance. The training objective is binary cross entropy.
        For more information, see: https://openreview.net/pdf?id=Ov_sMNau-PF

    """
    def __init__(self):
        super().__init__()
        self.criterion = nn.BCEWithLogitsLoss(reduction='sum')

    def forward(self, reps, labels: Tensor):
        # reps holds the outputs of the two encoders:
        # reps_1 = model1(sentence_features1)['sentence_embedding']  # (bsz, hdim)
        # reps_2 = model2(sentence_features2)['sentence_embedding']  # (bsz, hdim)
        reps_1, reps_2 = reps

        # Row-wise dot product S1 S2^T between paired embeddings -> (bsz,) logits
        sim_scores = torch.matmul(reps_1[:, None], reps_2[:, :, None]).squeeze(-1).squeeze(-1)

        loss = self.criterion(sim_scores, labels.type_as(sim_scores))
        return loss
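

# Usage sketch (illustrative, not part of the module): builds one CT-style
# mini-batch with K = 7 negatives and computes the loss. Random tensors stand
# in for the model1/model2 embeddings; all names below are hypothetical.
#
#   loss_fct = ContrastiveTensionLoss()
#   bsz, hdim = 8, 768
#   reps_1 = torch.randn(bsz, hdim)   # model1 embeddings of a_1, ..., a_{K+1}
#   reps_2 = torch.randn(bsz, hdim)   # model2 embeddings of p_1, ..., p_{K+1}
#   labels = torch.zeros(bsz)
#   labels[0] = 1.0                   # only the first pair is a positive pair
#   loss = loss_fct((reps_1, reps_2), labels)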


class ContrastiveTensionLossInBatchNegatives(nn.Module):
    def __init__(self, scale: float = 20.0, similarity_fct=cos_sim):
        """
        :param scale: Initial value for the learnable logit scale; similarity scores are multiplied by exp(logit_scale)
        :param similarity_fct: Similarity function applied between the two sets of embeddings (default: cosine similarity)
        """
        super().__init__()
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        # Learnable temperature, initialized so that exp(logit_scale) == scale
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(scale))

    def forward(self, reps, labels: Tensor):
        # reps holds the two encoders' sentence embeddings:
        # embeddings_a = model1(sentence_features1)['sentence_embedding']  # (bsz, hdim)
        # embeddings_b = model2(sentence_features2)['sentence_embedding']  # (bsz, hdim)
        embeddings_a, embeddings_b = reps

        # Scaled similarity matrix between all in-batch pairs: (bsz, bsz)
        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.logit_scale.exp()
        # With in-batch negatives, the matching pair sits on the diagonal, so the
        # target index for row/column i is i; the incoming `labels` are ignored.
        labels = torch.arange(len(scores), dtype=torch.long, device=scores.device)
        # Symmetric cross-entropy over rows (a -> b) and columns (b -> a)
        return (self.cross_entropy_loss(scores, labels) + self.cross_entropy_loss(scores.t(), labels)) / 2
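

# Usage sketch (illustrative): with in-batch negatives, every row is a positive
# pair and all off-diagonal entries act as negatives, so no special batch
# layout is required. Random tensors stand in for the two encoders' outputs.
#
#   loss_fct = ContrastiveTensionLossInBatchNegatives(scale=20.0)
#   embeddings_a = torch.randn(32, 768)
#   embeddings_b = torch.randn(32, 768)
#   loss = loss_fct((embeddings_a, embeddings_b), labels=None)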

################# CT Data Loader #################
# For CT, we need batches in a specific format.
# Within every group of pos_neg_ratio pairs, the first is a positive pair (i.e. [sentA, sentA])
# and the rest are negative pairs (i.e. [sentA, sentB]); with the default pos_neg_ratio of 8,
# that is 1 positive and 7 negative pairs per group.
# To achieve this, we create a custom DataLoader that produces batches with this property.

# The reference implementation below is kept for documentation purposes. If
# uncommented, it additionally requires `import math`, `import random`, and the
# InputExample class (e.g. from sentence_transformers.readers; the exact import
# path depends on the surrounding package).
#
# class ContrastiveTensionDataLoader:
#     def __init__(self, sentences, batch_size, pos_neg_ratio=8):
#         self.sentences = sentences
#         self.batch_size = batch_size
#         self.pos_neg_ratio = pos_neg_ratio
#         self.collate_fn = None
#
#         if self.batch_size % self.pos_neg_ratio != 0:
#             raise ValueError(f"ContrastiveTensionDataLoader was loaded with a pos_neg_ratio of {pos_neg_ratio} and a batch size of {batch_size}. The batch size must be divisible by the pos_neg_ratio")
#
#     def __iter__(self):
#         random.shuffle(self.sentences)
#         sentence_idx = 0
#         batch = []
#
#         while sentence_idx + 1 < len(self.sentences):
#             s1 = self.sentences[sentence_idx]
#             if len(batch) % self.pos_neg_ratio > 0:    # Negative (different) pair
#                 sentence_idx += 1
#                 s2 = self.sentences[sentence_idx]
#                 label = 0
#             else:                                      # Positive (identical) pair
#                 s2 = self.sentences[sentence_idx]
#                 label = 1
#
#             sentence_idx += 1
#             batch.append(InputExample(texts=[s1, s2], label=label))
#
#             if len(batch) >= self.batch_size:
#                 yield self.collate_fn(batch) if self.collate_fn is not None else batch
#                 batch = []
#
#     def __len__(self):
#         # Approximate number of batches: on average, each pair consumes ~2 sentences
#         return math.floor(len(self.sentences) / (2 * self.batch_size))
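#
# Usage sketch (illustrative; assumes the class above is uncommented and the
# imports mentioned there are available):
#
#   sentences = [f"sentence {i}" for i in range(1000)]
#   loader = ContrastiveTensionDataLoader(sentences, batch_size=16, pos_neg_ratio=8)
#   for batch in loader:
#       # batch is a list of 16 InputExample objects; within every group of 8,
#       # the first has label 1 (identical texts) and the remaining 7 have label 0
#       print(len(batch), batch[0].label)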